-
Notifications
You must be signed in to change notification settings - Fork 66
Improve Typing of ModelInterface #215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughThe Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested labels
Poem
Note 🔌 MCP (Model Context Protocol) integration is now available in Early Access!Pro users can now connect to remote MCP servers under the Integrations page to get reviews and chat conversations that understand additional development context. ✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
torch_sim/workflows/a2c.py(2 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (1)
torch_sim/workflows/a2c.py (1)
torch_sim/optimizers.py (1)
FireState(441-482)
🪛 GitHub Check: codecov/patch
torch_sim/workflows/a2c.py
[warning] 825-825: torch_sim/workflows/a2c.py#L825
Added line #L825 was not covered by tests
[warning] 827-827: torch_sim/workflows/a2c.py#L827
Added line #L827 was not covered by tests
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (37)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/models/test_mace.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: build-docs
🔇 Additional comments (1)
torch_sim/workflows/a2c.py (1)
825-827: LGTM! Explicit typing improves type safety.The explicit
floattyping forfinal_energyandfinal_pressurehelps the type checker understand the variable types and improves code clarity.
uvx ty checkuvx ty check or uvx mypy torch_sim --strict-equality
uvx ty check or uvx mypy torch_sim --strict-equalityuvx basedpyright
|
For visibility we are probably going to type the entire repo in parts. We'll merge this PR soon or break it up into more understandable parts |
|
|
||
|
|
||
| class ModelInterface(ABC): | ||
| class ModelInterface(torch.nn.Module, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice change
|
@coderabbitai review |
✅ Actions performedReview triggered.
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 7
🔭 Outside diff range comments (14)
torch_sim/models/particle_life.py (2)
193-193: Return type bug: energy stored as Python float instead of torch.Tensorresults["energy"] is set to 0.0 (float), but the return type is dict[str, torch.Tensor] and forward later stacks energies across systems. This will break stacking and typing.
Apply this diff:
- results = {"energy": 0.0} + results = { + "energy": torch.tensor(0.0, dtype=self.dtype, device=self.device) + }
196-199: Incorrect call signature: unexpected keyword 'epsilon' and missing required 'A'asymmetric_particle_pair_force_jit(dr, A, beta, sigma) does not accept epsilon. Passing epsilon=... will fail type-checks and at runtime.
Use A=self.epsilon and keep sigma keyword:
- pair_forces = asymmetric_particle_pair_force_jit( - distances, sigma=self.sigma, epsilon=self.epsilon - ) + pair_forces = asymmetric_particle_pair_force_jit( + distances, + A=self.epsilon, + sigma=self.sigma, + )Optionally pass beta if you intend a non-default inner radius:
+ beta=DEFAULT_BETA,torch_sim/models/fairchem.py (1)
324-341: Alignforwardreturn type withModelInterfaceThe
forwardmethod intorch_sim/models/fairchem.pymust use the same return annotation as the base interface to satisfy static checks.• File:
torch_sim/models/fairchem.py,forwarddefinition (around lines 324–341)
• Change return type fromdicttodict[str, torch.Tensor]- def forward(self, state: ts.SimState | StateDict) -> dict: + def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:Also ensure every value in the returned dictionary is a
torch.Tensor.torch_sim/models/orb.py (1)
333-347: Fix model-path handling and isinstance union for portabilityTwo issues in init:
self.system_config = system_config or model.system_configruns before path loading; ifmodelis a str/Path, this fails.isinstance(model, str | Path)may not be portable; use a tuple for compatibility.- # Set up system configuration - self.system_config = system_config or model.system_config - self._max_num_neighbors = max_num_neighbors - self._edge_method = edge_method - self._half_supercell = half_supercell - self.conservative = conservative - - # Load model if path is provided - if isinstance(model, str | Path): - model = torch.load(model, map_location=self._device) + # Load model if path is provided + if isinstance(model, (str, Path)): + model = torch.load(model, map_location=self._device) + + # Set up system configuration (after model is resolved) + self.system_config = system_config or model.system_config + self._max_num_neighbors = max_num_neighbors + self._edge_method = edge_method + self._half_supercell = half_supercell + self.conservative = conservativetorch_sim/models/graphpes.py (1)
102-142: Override signature should match ModelInterface (add**kwargs)ModelInterface.forward accepts
**kwargs. Overriding without**kwargscan trigger basedpyright override-compatibility errors. Accept and ignore extra kwargs in the wrapper.Apply this diff to the method signature:
- def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: + def forward( + self, state: ts.SimState | StateDict, **kwargs: Any + ) -> dict[str, torch.Tensor]:Optional: instead of
# type: ignore[return-value], cast to the expected mapping to keep the file clean for basedpyright:- return self._gp_model.predict(atomic_graph, self._properties) # type: ignore[return-value] + return typing.cast( + dict[str, torch.Tensor], + self._gp_model.predict(atomic_graph, self._properties), + )torch_sim/models/metatomic.py (2)
51-62: Override signature should match ModelInterface (add**kwargs)ModelInterface.forward includes
**kwargs. This override omits it, which can fail type checking with basedpyright. Accept and ignore extra kwargs.- def forward( # noqa: C901, PLR0915 - self, - state: ts.SimState | StateDict, - ) -> dict[str, torch.Tensor]: + def forward( # noqa: C901, PLR0915 + self, + state: ts.SimState | StateDict, + **kwargs: Any, + ) -> dict[str, torch.Tensor]:
99-116:Pathhandling bug:.endswithused onPath
modelis typedstr | Path | Nonebut the implementation callsmodel.endswith(...). This will raise onPath. Normalize viaPath(model)when not using the literal"pet-mad".- if model == "pet-mad": + if model == "pet-mad": path = "https://huggingface.co/lab-cosmo/pet-mad/resolve/v1.1.0/models/pet-mad-v1.1.0.ckpt" self._model = load_model(path).export() - elif model.endswith(".ckpt"): - path = model - self._model = load_model(path).export() - elif model.endswith(".pt"): - path = model - self._model = load_atomistic_model(path, extensions_path) + else: + path_obj = Path(model) # type: ignore[arg-type] + if path_obj.suffix == ".ckpt": + self._model = load_model(str(path_obj)).export() + elif path_obj.suffix == ".pt": + self._model = load_atomistic_model(str(path_obj), extensions_path) + else: + raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".') - else: - raise ValueError('Model must be a path to a .ckpt/.pt file, or "pet-mad".')torch_sim/models/interface.py (1)
135-170: Ensure allforwardoverrides include**kwargsThe base
ModelInterface.forwarddeclares**kwargsto support model-specific options and satisfy LSP/type checking. The following subclasses currently defineforwardwithout**kwargsand should be updated:
- torch_sim/models/fairchem.py:324
- torch_sim/models/graphpes.py:177
- torch_sim/models/lennard_jones.py:356
- torch_sim/models/mace.py:240
- torch_sim/models/mattersim.py:113
- torch_sim/models/metatomic.py:151
- torch_sim/models/morse.py:346
- torch_sim/models/orb.py:361
- torch_sim/models/particle_life.py:213
- torch_sim/models/sevennet.py:152
- torch_sim/models/soft_sphere.py:380, 807
Please revise each signature, for example:
- def forward(self, state: SimState | StateDict) -> dict[str, torch.Tensor]: + def forward(self, state: SimState | StateDict, **kwargs) -> dict[str, torch.Tensor]:This ensures all overrides remain signature-compatible and maintain proper extensibility.
torch_sim/models/soft_sphere.py (5)
283-285: Fix wrong masses shape when constructing SimState from dict
torch.ones_like(state["positions"])produces shape [n_atoms, 3], but masses must be [n_atoms]. This will break both runtime and types.Apply:
- if isinstance(state, dict): - state = ts.SimState(**state, masses=torch.ones_like(state["positions"])) + if isinstance(state, dict): + masses = torch.ones( + state["positions"].shape[0], dtype=self.dtype, device=self.device + ) + state = ts.SimState(**state, masses=masses)
351-353: Guard against division by zero in force projectionIf any pair has zero distance (overlap),
(pair_forces / distances)will produce inf/NaN. Clamp/guard the denominator.- force_vectors = (pair_forces / distances)[:, None] * dr_vec + denom = torch.where(distances > 0, distances, torch.ones_like(distances)) + force_vectors = (pair_forces / denom)[:, None] * dr_vec
606-610: Breaks when species=None: cannot infer n_species
n_species = len(torch.unique(species))fails ifspecies is None. You either needspeciesor a way to infer n_species from provided matrices. Minimal fix: infer from any provided *_matrix; otherwise requirespecies.- self.species = species - n_species = len(torch.unique(species)) + self.species = species + if species is None: + if sigma_matrix is not None: + n_species = sigma_matrix.shape[0] + elif epsilon_matrix is not None: + n_species = epsilon_matrix.shape[0] + elif alpha_matrix is not None: + n_species = alpha_matrix.shape[0] + else: + raise ValueError( + "Provide `species` or at least one of {sigma_matrix, epsilon_matrix, alpha_matrix} to infer n_species." + ) + else: + n_species = int(torch.unique(species).numel())
711-718: Typo in vesin_nl_ts argument:sorti→sort_idThis will raise a TypeError at runtime.
- sorti=False, + sort_id=False,
870-876: Incorrect stacking/cat of results: double-processing forces/energies/stressesYou first stack 'forces', 'energies', 'stresses' and then immediately overwrite by concatenating them. Follow the same pattern as SoftSphereModel: stack only 'stress' and 'energy'; cat the rest.
- for key in ("stress", "energy", "forces", "energies", "stresses"): + for key in ("stress", "energy"): if key in properties: results[key] = torch.stack([out[key] for out in outputs])(The following cat loop stays as-is.)
torch_sim/integrators/npt.py (1)
1460-1499: Avoid constructing NPTNoseHooverState with momenta=None (type mismatch); initialize momenta and thermostat first
NPTNoseHooverStatedeclaresmomenta: torch.Tensor. PassingNonewill fail strict type checking (and can break runtime assumptions). Computemomentaearlier and pass it into the constructor. Also initialize the thermostat using the computed KE.- # Create initial state - npt_state = NPTNoseHooverState( + # Initialize particle momenta before building the state + momenta = kwargs.get( + "momenta", + calculate_momenta( + state.positions, state.masses, state.system_idx, kT, seed + ), + ) + + # Initialize thermostat using particle KE + KE = calc_kinetic_energy( + masses=state.masses, momenta=momenta, system_idx=state.system_idx + ) + thermostat = thermostat_fns.initialize(state.positions.numel(), KE, kT) + + # Create initial state with concrete tensors (no Nones) + npt_state = NPTNoseHooverState( positions=state.positions, - momenta=None, + momenta=momenta, energy=energy, forces=forces, masses=state.masses, atomic_numbers=atomic_numbers, cell=state.cell, pbc=state.pbc, system_idx=state.system_idx, reference_cell=reference_cell, cell_position=cell_position, cell_momentum=cell_momentum, cell_mass=cell_mass, barostat=barostat_fns.initialize(1, KE_cell, kT), - thermostat=None, + thermostat=thermostat, barostat_fns=barostat_fns, thermostat_fns=thermostat_fns, ) - - # Initialize momenta - momenta = kwargs.get( - "momenta", - calculate_momenta( - npt_state.positions, npt_state.masses, npt_state.system_idx, kT, seed - ), - ) - - # Initialize thermostat - npt_state.momenta = momenta - KE = calc_kinetic_energy( - momenta=npt_state.momenta, - masses=npt_state.masses, - system_idx=npt_state.system_idx, - ) - npt_state.thermostat = thermostat_fns.initialize( - npt_state.positions.numel(), KE, kT - )
♻️ Duplicate comments (1)
torch_sim/models/interface.py (1)
38-38: Base class now inherits nn.Module — agreedConsolidating nn.Module at the interface simplifies model inheritance and typing.
🧹 Nitpick comments (4)
torch_sim/models/particle_life.py (2)
51-69: Docstring says “Energies” but function returns forcesasymmetric_particle_pair_force_jit computes force magnitudes, not energies. Update the Returns description to prevent confusion and improve type-check clarity.
184-185: I’d like to inspect the full mapping block in torch_sim/models/particle_life.py to verify why it’s reversed. Could you run:rg -n "mapping = torch.stack" -A 10 torch_sim/models/particle_life.pytorch_sim/models/mattersim.py (1)
24-34: LGTM: Placeholder class now correctly derives from ModelInterfaceThis keeps typing consistent when imports fail and still raises eagerly in init. One nit: the docstring says “sevenn is not installed” — should be “mattersim”.
- It raises an ImportError if sevenn is not installed. + It raises an ImportError if MatterSim is not installed.torch_sim/models/metatomic.py (1)
79-97: Docstring vs implementation mismatch formodel=NonebehaviorDocstring says “If None, defaults to 'pet-mad'” but the code raises ValueError when
model is None. Please align behavior with docs, or update the docstring and type to removeNone.Two options:
- Update implementation to default to "pet-mad":
- if model is None: - raise ValueError( - "A model path, or the name of a pre-defined model, must be provided. " - 'Currently only "pet-mad" is available as a pre-defined model.' - ) + if model is None: + model = "pet-mad"
- Or update the docstring and annotation (remove
| None) to reflect that a value is required.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (45)
.github/PULL_REQUEST_TEMPLATE.md(1 hunks)examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py(1 hunks)examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py(1 hunks)examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py(2 hunks)examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py(1 hunks)examples/scripts/3_Dynamics/3.2_MACE_NVE.py(1 hunks)examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py(1 hunks)examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py(1 hunks)examples/scripts/4_High_level_api/4.1_high_level_api.py(1 hunks)examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py(3 hunks)examples/tutorials/high_level_tutorial.py(1 hunks)examples/tutorials/reporting_tutorial.py(2 hunks)tests/models/test_mattersim.py(3 hunks)tests/models/test_sevennet.py(3 hunks)tests/test_integrators.py(4 hunks)tests/test_monte_carlo.py(2 hunks)tests/test_optimizers.py(15 hunks)tests/test_runners.py(5 hunks)tests/test_state.py(1 hunks)tests/test_trajectory.py(2 hunks)tests/workflows/test_a2c.py(2 hunks)torch_sim/elastic.py(2 hunks)torch_sim/integrators/md.py(2 hunks)torch_sim/integrators/npt.py(8 hunks)torch_sim/integrators/nve.py(1 hunks)torch_sim/integrators/nvt.py(5 hunks)torch_sim/models/fairchem.py(2 hunks)torch_sim/models/graphpes.py(2 hunks)torch_sim/models/interface.py(2 hunks)torch_sim/models/lennard_jones.py(1 hunks)torch_sim/models/mace.py(2 hunks)torch_sim/models/mattersim.py(2 hunks)torch_sim/models/metatomic.py(2 hunks)torch_sim/models/morse.py(1 hunks)torch_sim/models/orb.py(2 hunks)torch_sim/models/particle_life.py(1 hunks)torch_sim/models/sevennet.py(2 hunks)torch_sim/models/soft_sphere.py(3 hunks)torch_sim/monte_carlo.py(2 hunks)torch_sim/optimizers.py(10 hunks)torch_sim/quantities.py(6 hunks)torch_sim/runners.py(1 hunks)torch_sim/state.py(10 hunks)torch_sim/trajectory.py(4 hunks)torch_sim/workflows/a2c.py(8 hunks)
✅ Files skipped from review due to trivial changes (16)
- examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py
- examples/scripts/4_High_level_api/4.1_high_level_api.py
- examples/scripts/3_Dynamics/3.2_MACE_NVE.py
- .github/PULL_REQUEST_TEMPLATE.md
- examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py
- examples/scripts/2_Structural_optimization/2.2_Soft_Sphere_FIRE.py
- torch_sim/monte_carlo.py
- torch_sim/integrators/md.py
- torch_sim/models/mace.py
- torch_sim/integrators/nve.py
- torch_sim/trajectory.py
- torch_sim/elastic.py
- tests/test_integrators.py
- examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py
- torch_sim/integrators/nvt.py
- torch_sim/optimizers.py
🧰 Additional context used
🧬 Code Graph Analysis (17)
tests/test_monte_carlo.py (2)
torch_sim/models/interface.py (1)
ModelInterface(38-170)tests/conftest.py (1)
lj_model(34-45)
examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (1)
torch_sim/quantities.py (1)
calc_kinetic_energy(101-134)
torch_sim/models/particle_life.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
tests/test_state.py (3)
torch_sim/integrators/nvt.py (1)
velocities(270-274)torch_sim/integrators/npt.py (1)
velocities(871-877)torch_sim/integrators/md.py (1)
velocities(45-49)
tests/workflows/test_a2c.py (1)
torch_sim/optimizers.py (2)
FireState(442-483)UnitCellFireState(621-692)
torch_sim/models/sevennet.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/fairchem.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/lennard_jones.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/mattersim.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/morse.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/graphpes.py (2)
torch_sim/models/interface.py (1)
ModelInterface(38-170)tests/models/test_graphpes.py (2)
ts_nequip_model(128-134)ts_mace_model(155-161)
tests/test_trajectory.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/metatomic.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
torch_sim/models/orb.py (1)
torch_sim/models/interface.py (1)
ModelInterface(38-170)
examples/tutorials/reporting_tutorial.py (2)
torch_sim/models/interface.py (1)
ModelInterface(38-170)torch_sim/state.py (1)
SimState(26-402)
torch_sim/quantities.py (4)
torch_sim/integrators/npt.py (2)
momenta(71-73)velocities(871-877)torch_sim/integrators/nvt.py (1)
velocities(270-274)torch_sim/integrators/md.py (1)
velocities(45-49)torch_sim/units.py (1)
MetalUnits(83-103)
torch_sim/workflows/a2c.py (2)
torch_sim/models/interface.py (1)
ModelInterface(38-170)torch_sim/optimizers.py (2)
FireState(442-483)UnitCellFireState(621-692)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (11)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (macos-14, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (macos-14, 3.11, highest, mattersim, tests/models/test_mattersim.py)
- GitHub Check: test-model (macos-14, 3.11, highest, metatomic, tests/models/test_metatomic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: build-docs
🔇 Additional comments (58)
examples/scripts/3_Dynamics/3.1_Lennard_Jones_NVE.py (1)
81-81: LGTM: purely a formatting consolidation of SimState kwargsNo functional changes; kwargs are correct and consistent with surrounding examples.
tests/workflows/test_a2c.py (3)
1-2: LGTM! Appropriate import addition.The
castimport from typing is correctly added to support the type casting needed later in the test.
9-9: LGTM! Correct import addition.Adding
FireStateimport is necessary for the type cast used later in the test function.
160-160: LGTM! Proper handling of ambiguous return type.The cast to
FireStatecorrectly handles the fact thatrandom_packed_structurecan return either aFireStateor a tuple. This ensures type safety while maintaining test functionality.torch_sim/quantities.py (5)
3-4: LGTM! Necessary import addition.The
castimport is correctly added to support type casting in the function implementations.
25-31: LGTM! Excellent improvement to argument handling.Making arguments keyword-only with the leading
*and makingmomentaoptional improves the API design. This enforces clearer calling conventions and prevents positional argument mistakes.
44-50: LGTM! Improved validation logic.The XOR validation using
((momenta is not None) ^ (velocities is not None))is more concise and clearer than separate checks. The explicit cast totorch.Tensorwhen velocities is guaranteed to be non-None ensures type safety.
73-98: LGTM! Consistent parameter improvements.The same improvements applied to
calc_temperature: keyword-only arguments, optional momenta parameter, and explicit keyword arguments in thecalc_kTcall. TheMetalUnitstype annotation is more precise than the previous generic object type.
101-122: LGTM! Consistent implementation pattern.The
calc_kinetic_energyfunction follows the same improved pattern with keyword-only arguments, optional momenta parameter, and consolidated XOR validation logic.torch_sim/state.py (7)
11-11: LGTM! Appropriate imports for enhanced type safety.The additional imports support the generic type improvements and explicit casting used throughout the file.
112-147: LGTM! Fixed constructor validation logic.The system_idx validation and assignment logic is now more robust. The validation properly handles the case where system_idx is provided, and the assignment is correctly performed after validation.
254-259: LGTM! Proper error handling for volume property.Raising a
ValueErrorfor non-PBC systems is the correct approach since volume is undefined without periodic boundaries. This is better than returningNonewhich could lead to silent errors downstream.
368-368: LGTM! Explicit type casting for method return.The cast ensures the return type matches the method signature while maintaining type safety.
408-420: LGTM! Simplified mixin interface.Removing the momenta property and adding the required attributes simplifies the mixin interface. The reference_row_vector_cell property provides a clean way to access the reference cell in row vector format.
483-490: LGTM! Enhanced type safety with generics.The SimStateT generic type variable and its usage in function signatures preserve subclass types through operations like device conversion, which is excellent for type safety.
972-972: LGTM! Appropriate type casting for validation.The cast ensures type safety when accessing the n_systems property on list elements during validation.
examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py (1)
100-100: LGTM! Consistent argument ordering.Moving
pbc=Trueafteratomic_numbersmaintains consistency with other example scripts while preserving functionality.torch_sim/models/morse.py (1)
115-115: LGTM! Consistent interface inheritance.Removing
torch.nn.Moduleinheritance aligns with the architectural change whereModelInterfacenow provides the PyTorch module functionality. This centralizes the inheritance pattern and reduces duplication across model implementations.torch_sim/models/particle_life.py (2)
86-86: Inheritance change aligns with new interface hierarchyGood move inheriting only from ModelInterface now that it subclasses nn.Module. super().init() is present.
155-161: Ignore CUDA scalar conversion suggestionThe
vesin_nl_tssignature clearly declaresdef vesin_nl_ts( …, cutoff: torch.Tensor, … )so it expects a tensor (not a Python float). Converting
self.cutoffto a Python float—even with.item()—would break the intended API. You can leave the call as-is.Likely an incorrect or invalid review comment.
torch_sim/models/lennard_jones.py (2)
122-122: Inheritance change matches new designSwitching to inherit only from ModelInterface (which is an nn.Module) is consistent with the refactor.
265-271: Ignore float conversion suggestion for cutoff
Thevesin_nl_tsfunction is scripted to accept a 0-dimtorch.Tensorand internally invokescutoff.item(), so passingself.cutoff(even on CUDA) is correct. Converting it to a Python float would mismatch the declaredtorch.Tensorparameter and break the JIT call.Likely an incorrect or invalid review comment.
examples/tutorials/high_level_tutorial.py (1)
135-136: Keyword-only call is correctSwitch to keyword args for calc_kinetic_energy matches the updated signature and improves clarity.
tests/test_monte_carlo.py (2)
6-6: Type import aligns test with new model interfaceImporting ModelInterface keeps tests consistent with the refactor.
116-117: Fixture type updated to ModelInterface is appropriateTyping lj_model as ModelInterface matches how models are consumed across the codebase.
torch_sim/models/fairchem.py (2)
48-56: Placeholder class now inherits ModelInterfaceGood change; keeps type checkers happy even when FairChem isn’t installed. Raising in init is fine here.
73-73: Main class inheritance updated appropriatelyInheriting only from ModelInterface centralizes nn.Module via the interface. super().init() is correctly called in init.
tests/test_trajectory.py (2)
12-12: ModelInterface import aligns the test with the new model abstractionGood change. This keeps the test consistent with ModelInterface-based typing across the codebase.
752-756: Type-narrowing to ModelInterface for energy_calculator — LGTMAnnotating the model parameter as ModelInterface is correct here since the calculator calls model(state) and accesses "energy".
examples/tutorials/reporting_tutorial.py (2)
209-209: Importing ModelInterface for tutorial examples — good consistencyThis ensures the tutorial code type-checks under basedpyright with the new interface.
218-221: Update calculate_energy to accept ModelInterface — correct and minimalAccurately reflects the contract that models implement ModelInterface and return an output dict with "energy".
examples/scripts/3_Dynamics/3.11_Lennard_Jones_NPT_Langevin.py (2)
101-102: Keyword argument order tweak for SimState — OKReordering pbc after atomic_numbers using keyword args is harmless and reads cleanly.
151-156: Keyword-only kinetic energy call and variable rename are correct
- calc_kinetic_energy now requires keyword args; usage is correct.
- Passing kinetic_energy to get_pressure matches the expected signature (stress, kinetic_energy, volume).
tests/test_state.py (1)
503-505: Require velocities and reference_cell in DeformState — good alignmentMaking these parameters required reflects the updated mixin expectations and clarifies test intent.
torch_sim/models/mattersim.py (1)
42-55: LGTM: Inherit only from ModelInterfaceGiven
ModelInterfacealready subclassestorch.nn.Module, this simplifies the hierarchy without functional changes.super().__init__()still initializesnn.Modulestate.torch_sim/models/orb.py (2)
42-52: LGTM: Placeholder class now uses ModelInterfaceConsistent with the new interface-first design and keeps the eager ImportError behavior.
250-274: LGTM: Inherit only from ModelInterfaceMatches the project-wide refactor and maintains proper
nn.Moduleinitialization viasuper().__init__().torch_sim/models/sevennet.py (2)
35-45: LGTM: Placeholder class inherits ModelInterfaceKeeps error surfacing while maintaining consistent typing.
47-59: LGTM: Inherit only from ModelInterfaceConsistent with project-wide refactor;
nn.Moduleis still initialized viasuper().__init__().torch_sim/models/graphpes.py (1)
39-49: ImportError fallback: inheritance change is correctUsing ModelInterface alone (dropping direct nn.Module) in the stub keeps the hierarchy consistent now that ModelInterface subclasses nn.Module.
torch_sim/models/metatomic.py (1)
39-49: ImportError fallback: inheritance change is correctUsing ModelInterface alone (dropping direct nn.Module) in the stub keeps the hierarchy consistent.
tests/test_runners.py (1)
26-30: Keyword-arg calls tocalc_kinetic_energyare correctSwitching to
momenta=..., masses=...improves clarity and typing; matches the current API and helps basedpyright.Also applies to: 63-67, 119-123, 170-174
torch_sim/models/interface.py (1)
72-75: Declaring protected attributes at class level is the right moveThese declarations unblock basedpyright by guaranteeing attribute presence on subclasses.
torch_sim/models/soft_sphere.py (3)
438-438: Multi-model now inherits ModelInterface — good alignment with the refactorNo issues with the base class swap.
597-602: Private attribute renaming is fine; confirm public properties back theseSetting
_device/_dtype/_compute_*is consistent with SoftSphereModel. As above, ensure ModelInterface exposes the publicdevice/dtype/compute_*properties used elsewhere.Same script as in Line 133 comment can be used to confirm the properties exist.
133-133: ModelInterface defines all required properties
Verified thatModelInterfacedeclares the following@propertymethods intorch_sim/models/interface.py:
device(l. 77)dtype(l. 89)compute_stress(l. 101)compute_forces(l. 113)No further changes needed.
torch_sim/integrators/npt.py (2)
17-17: Importing ModelInterface for type annotations — goodThis aligns the integrators with the new interface abstraction.
1436-1436: Consistent keyword-only calls to calc_kinetic_energy — LGTMAll updated calls now use
masses=/momenta=and passsystem_idxwhere per-system results are needed. This matches the API change and improves clarity.Also applies to: 1491-1494, 1550-1552, 1555-1555, 1607-1608
torch_sim/workflows/a2c.py (8)
22-22: LGTM! Import added for ModelInterface.The import is correctly added to support the type annotation changes throughout the file.
232-232: LGTM! Return type annotation correctly reflects conditional behavior.The union type
FireState | tuple[FireState, list[np.ndarray]]accurately represents that the function returns either aFireStateor a tuple based on whetherlogis provided.
580-597: LGTM! Performance optimization for composition handling.The refactoring to convert
restrict_to_compositionsto a set upfront and use.add()instead of list operations is more efficient. The explicit type annotations also improve type safety.
614-614: LGTM! Explicit type annotation improves type safety.The explicit typing of
candidatesaslist[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]makes the return type clear and helps with type checking.
709-709: LGTM! Parameter type updated to ModelInterface.The change from
torch.nn.ModuletoModelInterfaceis consistent with the codebase refactoring and provides more specific typing for models in the simulation framework.
711-711: LGTM! Return type expanded to include energy and pressure lists.The return type correctly reflects that the function now returns final energies and pressures as lists of floats in addition to the state and logger. This is consistent with the function implementation that prints and returns these values.
755-756: LGTM! Type annotations updated for consistency.The parameter and return type annotations are properly updated to use the new type system with explicit
UnitCellFireStateanddict[str, torch.Tensor]types.
330-331: Missing return statement in else branch.The function has a conditional return at line 329-330 but is missing a return statement for the case when
log is None. This will cause a runtime error.Apply this diff to fix the missing return:
if log is not None: return state, log - + return stateLikely an incorrect or invalid review comment.
tests/test_optimizers.py (2)
9-9: LGTM! Import added for ModelInterface.The import is correctly added to support the updated type annotations throughout the test file.
27-27: LGTM! Model parameter types updated to ModelInterface.All test functions now correctly use
ModelInterfaceinstead oftorch.nn.Modulefor thelj_modelparameter. This is consistent with the codebase refactoring where models inherit fromModelInterfacerather than directly fromtorch.nn.Module.Also applies to: 66-66, 115-115, 189-189, 205-205, 213-213, 276-276, 329-329, 418-418, 452-452, 471-471, 507-507, 596-596, 711-711, 787-787
| 1: { | ||
| "ke": lambda state: calc_kinetic_energy( | ||
| velocities=state.velocities, masses=state.masses | ||
| ) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Verify velocities usage in FIRE reporter
Using velocities=state.velocities is fine if SimState guarantees velocities during FIRE optimization and calc_kinetic_energy supports it. Please sanity-check both assumptions.
Run this script to verify the function signature and SimState fields:
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "Check calc_kinetic_energy signature:"
rg -n "def calc_kinetic_energy\\(" -A 5 torch_sim | sed -n '1,12p'
echo
echo "Confirm it accepts 'velocities' and/or 'momenta':"
rg -n "calc_kinetic_energy\\(" -A 2 torch_sim | rg -n "velocities|momenta|masses" -n
echo
echo "Check that SimState defines 'velocities' in contexts used by FIRE:"
rg -n "class SimState" -n torch_sim
rg -n "velocities" torch_sim | head -n 20Length of output: 4095
Ensure SimState provides velocities for the FIRE reporter
While calc_kinetic_energy does accept a velocities tensor (see torch_sim/quantities.py:def calc_kinetic_energy signature), SimState does not currently define or populate a velocities attribute:
• torch_sim/quantities.py (lines 101–107): supports velocities: Tensor | None
• torch_sim/state.py (line 966): TODO to pass velocities from external engines, no field defined
• torch_sim/runners.py (lines 49, 52): reporter calls velocities=state.velocities
Please address this by either:
- Adding/populating a
velocitiesfield on SimState during FIRE optimizations, or - Switching the reporter to use
momenta=state.momenta(which SimState does define)
🤖 Prompt for AI Agents
In tests/test_runners.py around lines 365 to 369, the reporter calls
calc_kinetic_energy with state.velocities, but SimState does not define or
populate a velocities attribute. To fix this, either add and populate a
velocities field on SimState during FIRE optimizations so it can be passed
correctly, or modify the reporter to use momenta=state.momenta instead, since
momenta is already defined on SimState.
| "kinetic_energy": lambda state: calc_kinetic_energy( | ||
| velocities=state.velocities, masses=state.masses | ||
| ), | ||
| "temperature": lambda state: calc_kT( | ||
| velocities=state.velocities, masses=state.masses | ||
| ), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Kinetic/temperature calculators now depend on velocities — verify availability or guard
Switch to keyword args looks good and aligns with API. However, if any integrator/state doesn’t populate state.velocities, these lambdas will raise at runtime. Either ensure all integrate paths guarantee velocities, or guard the calculators (e.g., skip or return NaN when velocities are absent).
You can quickly sanity-check for residual usage of momenta vs. assured velocities support with:
🏁 Script executed:
#!/bin/bash
# Inspect SimState and integrators for velocities/momenta usage
rg -n "velocities" torch_sim | head -n 50
rg -n "momenta" torch_sim | head -n 50
# Check if SimState defines 'velocities' as required attr
rg -n "class SimState" -A 40 torch_sim/state.pyLength of output: 11438
🏁 Script executed:
#!/bin/bash
# Inspect runner reporters setup for state.velocities usage
rg -n "reporters" -A30 torch_sim/runners.py
# Inspect how Trajectory uses save_velocities and when velocities are attached
rg -n "class Trajectory" -A200 torch_sim/trajectory.py
# Check save_velocities flag defaults and how velocities propagate into state
rg -n "save_velocities" -A20 torch_sim/trajectory.pyLength of output: 15077
🏁 Script executed:
#!/bin/bash
# Check for StaticState definition in the codebase
rg -n "class StaticState" -A20 torch_sim/state.pyLength of output: 51
🏁 Script executed:
#!/bin/bash
# Locate StaticState definition across the repo
rg -n "class StaticState" -A20 .
# Locate MDState definition across the repo
rg -n "class MDState" -A20 .
# Show the top of runners.py (including default prop_calculators)
sed -n '1,100p' torch_sim/runners.pyLength of output: 6753
Guard kinetic/temperature reporters against missing velocities
In _configure_reporter (torch_sim/runners.py), the default kinetic_energy and temperature lambdas assume state.velocities always exists. While MDState provides a velocities property, base SimState and static‐optimization states do not, leading to an AttributeError if users request these props outside an MD run.
• Location:
- torch_sim/runners.py, around lines 46–53 in
_configure_reporter
• Suggested change: wrap each calculator in a guard that checks forstate.velocitiesand returns NaN (or skips) when unavailable
Example diff:
--- a/torch_sim/runners.py
+++ b/torch_sim/runners.py
@@ def _configure_reporter(...):
- possible_properties = {
- "kinetic_energy": lambda state: calc_kinetic_energy(
- velocities=state.velocities, masses=state.masses
- ),
- "temperature": lambda state: calc_kT(
- velocities=state.velocities, masses=state.masses
- ),
+ possible_properties = {
+ "kinetic_energy": lambda state: (
+ calc_kinetic_energy(velocities=state.velocities, masses=state.masses)
+ if getattr(state, "velocities", None) is not None
+ else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
+ ),
+ "temperature": lambda state: (
+ calc_kT(velocities=state.velocities, masses=state.masses)
+ if getattr(state, "velocities", None) is not None
+ else torch.tensor(float("nan"), device=state.device, dtype=state.dtype)
+ ),This ensures reporters won’t crash when running on states without velocities.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "kinetic_energy": lambda state: calc_kinetic_energy( | |
| velocities=state.velocities, masses=state.masses | |
| ), | |
| "temperature": lambda state: calc_kT( | |
| velocities=state.velocities, masses=state.masses | |
| ), | |
| possible_properties = { | |
| "kinetic_energy": lambda state: ( | |
| calc_kinetic_energy(velocities=state.velocities, masses=state.masses) | |
| if getattr(state, "velocities", None) is not None | |
| else torch.tensor(float("nan"), device=state.device, dtype=state.dtype) | |
| ), | |
| "temperature": lambda state: ( | |
| calc_kT(velocities=state.velocities, masses=state.masses) | |
| if getattr(state, "velocities", None) is not None | |
| else torch.tensor(float("nan"), device=state.device, dtype=state.dtype) | |
| ), | |
| # … other properties … | |
| } |
🤖 Prompt for AI Agents
In torch_sim/runners.py around lines 48 to 53, the lambdas for "kinetic_energy"
and "temperature" assume state.velocities always exists, which can cause
AttributeError for states without velocities. Modify these lambdas to first
check if state.velocities is present; if not, return NaN to prevent crashes.
This guard ensures the reporters handle states lacking velocities gracefully.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this is a larger issue. we'll prob fix this in another PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@curtischong, that makes perfect sense! Keeping the scope focused on type checking compliance for this PR is a good approach. The runtime safety issue with missing velocities can definitely be addressed separately.
Would you like me to create a GitHub issue to track this for the future PR? It would help ensure the velocities availability issue doesn't get forgotten.
✏️ Learnings added
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.
[skip ci] Signed-off-by: Rhys Goodall <[email protected]>
| torch.Tensor: Scalar temperature value | ||
| """ | ||
| if momenta is not None and velocities is not None: | ||
| if not ((momenta is not None) ^ (velocities is not None)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the value error is slightly misleading. if both of them are none then the xor will evaluate to false. but the not will negate it and the entire statement will be true.
so if they pass in no velocity and no mementia, the value error will raise. but the "not both" is misleading
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a minor nit though. Oh I get it. you merged the if statements. How about: "Must pass in one of either momenta or velocities"
| If system_idx is provided: Tensor of kinetic energies per system | ||
| """ | ||
| if momenta is not None and velocities is not None: | ||
| if not ((momenta is not None) ^ (velocities is not None)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe update this value error as well?
uvx basedpyrightThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/test_quantities.py (2)
3-3: Use public API for tensor type annotation.Importing from
torch._tensoraccesses a private module. Usetorch.Tensordirectly for better compatibility with type checkers.-from torch._tensor import Tensor +Then update the type annotations to use
torch.Tensorinstead ofTensor:-def single_system_data() -> dict[str, Tensor]: +def single_system_data() -> dict[str, torch.Tensor]:-def batched_system_data() -> dict[str, Tensor]: +def batched_system_data() -> dict[str, torch.Tensor]:-def test_calc_kinetic_energy_single_system(single_system_data: dict[str, Tensor]) -> None: +def test_calc_kinetic_energy_single_system(single_system_data: dict[str, torch.Tensor]) -> None:And similarly for other function signatures.
127-137: Consider adding batched system test coverage.The temperature calculation test correctly verifies the unit conversion relationship with kT. Consider adding a test case for batched systems to match the coverage of other test functions.
def test_calc_temperature_batched_system(batched_system_data: dict[str, torch.Tensor]) -> None: temp = quantities.calc_temperature( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], ) kt = quantities.calc_kT( masses=batched_system_data["masses"], velocities=batched_system_data["velocities"], system_idx=batched_system_data["system_idx"], ) assert torch.allclose(temp, kt / MetalUnits.temperature)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tests/test_quantities.py(1 hunks)torch_sim/quantities.py(6 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- torch_sim/quantities.py
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: curtischong
PR: Radical-AI/torch-sim#215
File: torch_sim/runners.py:48-53
Timestamp: 2025-08-08T02:04:31.517Z
Learning: In the torch-sim project, when reviewing PRs, scope expansion issues (like runtime safety improvements) that are outside the current PR's main objective should be acknowledged but deferred to separate PRs rather than insisting on immediate fixes.
🧬 Code Graph Analysis (1)
tests/test_quantities.py (2)
torch_sim/units.py (1)
MetalUnits(83-103)torch_sim/quantities.py (3)
calc_kinetic_energy(101-134)calc_kT(25-70)calc_temperature(73-97)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (58)
- GitHub Check: test-examples (examples/scripts/7_Others/7.3_Batched_neighbor_list.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.4_Velocity_AutoCorrelation.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.6_Compare_ASE_to_VV_FIRE.py)
- GitHub Check: test-examples (examples/scripts/7_Others/7.1_Soft_sphere_autograd.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.3_Fairchem.py)
- GitHub Check: test-examples (examples/scripts/1_Introduction/1.2_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.2_QuasiHarmonic_MACE.py)
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.1_Phonons_MACE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.1_Lennard_Jones_FIRE.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.5_MACE_UnitCellFilter_Gradient_Descen...
- GitHub Check: test-examples (examples/scripts/6_Phonons/6.3_Conductivity_MACE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.12_MACE_NPT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.3_MACE_Gradient_Descent.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.7_MACE_FrechetCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.7_Lennard_Jones_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.4_MACE_FIRE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.9_MACE_NVT_staggered_stress.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.2_In_Flight_WBM.py)
- GitHub Check: test-examples (examples/scripts/2_Structural_optimization/2.6_MACE_UnitCellFilter_FIRE.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.1_high_level_api.py)
- GitHub Check: test-examples (examples/scripts/4_High_level_api/4.2_auto_batching_api.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.3_Elastic.py)
- GitHub Check: test-examples (examples/scripts/5_Workflow/5.1_a2c_silicon_batched.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.8_MACE_NPT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.10_Hybrid_swap_mc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.4_MACE_NVT_Langevin.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.6_MACE_NVT_Nose_Hoover_temp_profile.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.3_MACE_NVE_cueq.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.2_MACE_NVE.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.13_MACE_NVE_non_pbc.py)
- GitHub Check: test-examples (examples/scripts/3_Dynamics/3.5_MACE_NVT_Nose_Hoover.py)
- GitHub Check: test-examples (examples/tutorials/hybrid_swap_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/low_level_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/state_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/reporting_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/autobatching_tutorial.py)
- GitHub Check: test-examples (examples/tutorials/high_level_tutorial.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (macos-14, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.12, lowest-direct, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, sevenn, tests/models/test_sevennet.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, orb, tests/models/test_orb.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_optimizers_vs_ase.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/models/test_mace.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, mace, tests/test_elastic.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, graphpes, tests/models/test_graphpes.py)
- GitHub Check: test-model (ubuntu-latest, 3.11, highest, fairchem, tests/models/test_fairchem.py)
- GitHub Check: test-core (macos-14, 3.12, lowest-direct)
- GitHub Check: test-core (ubuntu-latest, 3.11, highest)
- GitHub Check: test-core (ubuntu-latest, 3.12, lowest-direct)
- GitHub Check: test-core (macos-14, 3.11, highest)
- GitHub Check: build-docs
🔇 Additional comments (4)
tests/test_quantities.py (4)
13-26: LGTM!The fixture provides accurate test data. The expected kinetic energy (13.5) and kT (4.5) values are mathematically correct for the given masses and velocities.
29-44: LGTM!The batched system fixture provides accurate test data. The expected kinetic energies ([3.0, 24.0]) and kT values ([1.0, 8.0]) per system are mathematically correct.
47-92: Excellent test coverage for kinetic energy calculations.The tests comprehensively cover:
- Single and batched system scenarios
- Both velocity and momentum input modes
- Proper error handling for invalid argument combinations
- Correct usage of keyword-only arguments
94-125: LGTM!The kT calculation tests provide comprehensive coverage matching the kinetic energy test structure, correctly testing both single and batched systems with velocity and momentum inputs.
curtischong
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
overall I approve - 2 minor nits you might want to change before we merge
Summary
By ensuring that basedpyright runs, we reduce the chance of bugs in our code
Checklist
Before a pull request can be merged, the following items must be checked:
Run ruff on your code.
We highly recommended installing the pre-commit hooks running in CI locally to speedup the development process. Simply run
pip install pre-commit && pre-commit installto install the hooks which will check your code before each commit.Summary by CodeRabbit
Summary by CodeRabbit